# -*- coding: utf-8 -*-
"""
Generating samples by linearly combining two input images.
"""
from __future__ import absolute_import, print_function, division

import numpy as np
import tensorflow as tf

from niftynet.engine.image_window_dataset import ImageWindowDataset
from niftynet.contrib.csv_reader.sampler_csv_rows import ImageWindowDatasetCSV
from niftynet.engine.image_window import N_SPATIAL, LOCATION_FORMAT


class LinearInterpolateSamplerCSV(ImageWindowDatasetCSV):
    """
    This class reads two feature vectors from files (often generated
    by running feature extractors on images in advance)
    and returns n linear combinations of the vectors.
    The coefficients are generated by::

        np.linspace(0, 1, n_interpolations)
    """

    def __init__(self,
                 reader,
                 csv_reader,
                 window_sizes,
                 batch_size=10,
                 n_interpolations=10,
                 queue_length=10,
                 name='linear_interpolation_sampler'):
        ImageWindowDatasetCSV.__init__(
            self,
            reader,
            csv_reader=csv_reader,
            window_sizes=window_sizes,
            batch_size=batch_size,
            queue_length=queue_length,
            shuffle=False,
            epoch=1,
            smaller_final_batch_mode='drop',
            name=name)
        self.n_interpolations = n_interpolations
        # only try to use the first spatial shape available
        image_spatial_shape = list(self.reader.shapes.values())[0][:3]
        self.window.set_spatial_shape(image_spatial_shape)
        tf.logging.info(
            "initialised linear interpolation sampler %s ", self.window.shapes)
        assert not self.window.has_dynamic_shapes, \
            "dynamic shapes not supported, please specify " \
            "spatial_window_size = (1, 1, 1)"

    def layer_op(self, *_unused_args, **_unused_kwargs):
        """
        This function first reads two vectors, and interpolates them
        with self.n_interpolations mixing coefficients.

        Location coordinates are set to ``np.ones`` for all the vectors.
        """
        output_dict = {}
        image_id_x, data_x, _ = self.reader(idx=None, shuffle=False)
        image_id_y, data_y, _ = self.reader(idx=None, shuffle=True)
        if not data_x or not data_y:
            return
        if image_id_x == image_id_y:
            while image_id_x == image_id_y:
                image_id_x, data_x, _ = self.reader(idx=None, shuffle=False)
                image_id_y, data_y, _ = self.reader(idx=None, shuffle=True)
                if not data_x or not data_y:
                    return
        embedding_x = data_x[self.window.names[0]]
        embedding_y = data_y[self.window.names[0]]

        steps = np.linspace(0, 1, self.n_interpolations)
        for (_, mixture) in enumerate(steps):
            output_vector = \
                embedding_x * mixture + embedding_y * (1 - mixture)
            coordinates = np.ones((1, N_SPATIAL * 2 + 1), dtype=np.int32)
            coordinates[0, 0:2] = [image_id_x, image_id_y]
            output_dict = {}
            for name in self.window.names:
                coordinates_key = LOCATION_FORMAT.format(name)
                image_data_key = name
                output_dict[coordinates_key] = coordinates
                output_dict[image_data_key] = output_vector[np.newaxis, ...]
        if self.csv_reader is not None:
            _, label_dict_x, _ = self.csv_reader(idx=image_id_x)
            _, label_dict_y, _ = self.csv_reader(idx=image_id_y)
            output_dict.update(label_dict_x)
            output_dict.update(label_dict_y)
            for name in self.csv_reader.names():
                    output_dict[name + '_location'] = output_dict[
                            'image_location']
        return output_dict
